#!/usr/bin/env python3

import torch

def token_cossim(
    x: torch.Tensor,
    vs_clstoken: bool = False,
) -> torch.Tensor:
    r"""Calculate the cosine similarity between pairs of tokens.

    Every token is L2-normalised along the last dimension first, so each
    returned value is the dot product of two unit vectors. Tokens whose norm
    is zero are left as zero vectors, so their similarity to any other token
    is 0 rather than NaN.

    Args:
        x (torch.Tensor): A tensor of tokens. Expected size is
            (Batch, Layer, numToken, dim).
        vs_clstoken (bool, optional): If True, calculate the similarity
            between the first token (index 0, treated as the cls-token) and
            every other token. If False, calculate the similarity of every
            unordered pair of distinct tokens. Defaults to False.

    Returns:
        torch.Tensor: The cosine similarities.

        * ``vs_clstoken=True``: size (Batch, Layer, numToken - 1).
        * ``vs_clstoken=False``: size (Batch, Layer, n * (n - 1) / 2), where
          ``n`` is the number of tokens left after the first token has been
          dropped when ``numToken`` is odd. Pairs ``(i, j)`` with ``i < j``
          are listed in the row-major order of ``torch.triu_indices``.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional.
    """

    # Kept as an assert so the exception type seen by callers is unchanged.
    assert x.dim() == 4, "Expected size is (Batch,Layer,numToken,dim)"

    # L2-normalise each token. The norm is clamped to the smallest positive
    # normal value of its dtype so an all-zero token yields 0 instead of
    # 0 / 0 = NaN; tokens with a non-zero norm are unaffected. Using the
    # dtype's own ``tiny`` keeps the guard effective in half precision too.
    norm = torch.norm(x, 2, dim=-1, keepdim=True)
    x = x / norm.clamp_min(torch.finfo(norm.dtype).tiny)

    if vs_clstoken:
        # Dot product of token 0 with every remaining token.
        return torch.einsum("BLnd,BLd->BLn", x[:, :, 1:], x[:, :, 0])

    # An odd token count is taken to mean a leading cls-token is present
    # (ViT-style layout); drop it so only the remaining tokens are compared.
    if x.size(2) % 2 == 1:
        x = x[:, :, 1:]

    n = x.size(2)
    # Build the index pairs on the same device as the data.
    rows, cols = torch.triu_indices(n, n, offset=1, device=x.device)

    # One (n, n) Gram matrix per (batch, layer), then gather the strict upper
    # triangle. This avoids materialising two (Batch, Layer, n*(n-1)/2, dim)
    # gathers, which dominate memory for realistic token counts.
    gram = torch.matmul(x, x.transpose(-1, -2))
    return gram[:, :, rows, cols]
